import torch

"""
配置文件 - 包含所有可配置的参数
"""

class Config:
    # 数据集配置
    DATASET = "CIFAR100"
    SUPERCLASS = True  # 使用超类分类
    BATCH_SIZE = 128
    NUM_WORKERS = 4
    
    # 模型配置
    MODEL_NAME = "resnet50"  # 使用ResNet50作为基础模型
    PRETRAINED = True       # 使用预训练权重
    NUM_CLASSES = 20        # CIFAR100有20个超类
    
    # 训练配置
    EPOCHS = 20
    LEARNING_RATE = 0.001
    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4
    LR_SCHEDULER = True     # 使用学习率调度器
    
    # 设备配置
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    
    # 路径配置
    SAVE_DIR = "./saved_models"
    LOG_DIR = "./logs"
    
    # 其他
    SEED = 42
    PRINT_FREQ = 100  # 每多少批次打印一次信息